//===----------------------------------------------------------------------===//
//
//                         BusTub
//
// hash_join_executor.h
//
// Identification: src/include/execution/executors/hash_join_executor.h
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common/util/hash_util.h"
#include "execution/executor_context.h"
#include "execution/executors/abstract_executor.h"
#include "execution/plans/hash_join_plan.h"
#include "storage/table/tuple.h"

namespace bustub {
struct JoinKey {
  std::vector<Value> join_cols_;

  auto operator==(const JoinKey &other) const -> bool {
    for (uint32_t i = 0; i < other.join_cols_.size(); i++) {
      if (join_cols_[i].CompareEquals(other.join_cols_[i]) != CmpBool::CmpTrue) {
        return false;
      }
    }
    return true;
  }

  auto ToString() const -> std::string {
    std::stringstream stream;
    for (const auto &join_col : join_cols_) {
      stream << join_col.ToString() << " ";
    }
    return stream.str();
  }
};

struct JoinValue {
  std::vector<Value> join_cols_;
  auto ToString() const -> std::string {
    std::stringstream stream;
    for (const auto &join_col : join_cols_) {
      stream << join_col.ToString() << " ";
    }
    return stream.str();
  }
};

}  // namespace bustub

namespace std {

/**
 * std::hash specialization so JoinKey can key unordered containers.
 * NULL components are skipped; the remaining values are folded together with
 * HashUtil::CombineHashes, starting from a seed of 0.
 */
template <>
struct hash<bustub::JoinKey> {
  auto operator()(const bustub::JoinKey &join_key) const -> std::size_t {
    std::size_t acc = 0;
    for (const auto &col : join_key.join_cols_) {
      if (col.IsNull()) {
        continue;  // NULL components contribute nothing to the hash
      }
      acc = bustub::HashUtil::CombineHashes(acc, bustub::HashUtil::HashValue(&col));
    }
    return acc;
  }
};

}  // namespace std

namespace bustub {

class HashJoinTable {
 public:
  void InsertCombine(const JoinKey &join_key, const JoinValue &join_val) {
    LOG_DEBUG("join_key=%s, join_val=%s", join_key.ToString().c_str(), join_val.ToString().c_str());
    ht_.emplace(join_key, join_val);
  }

  /**
   * Clear the hash table
   */
  void Clear() { ht_.clear(); }

  auto Count(const JoinKey &join_key) -> int { return ht_.count(join_key); }

 public:
  std::unordered_multimap<JoinKey, JoinValue> ht_;
};

/**
 * HashJoinExecutor executes a hash JOIN on two tables.
 *
 * Join keys are computed from each child's tuples with the plan's left/right
 * key-expression lists. MakeJoinValue captures right-child columns, which
 * suggests the right child is the build side of hjt_ — NOTE(review): confirm in
 * the .cpp.
 */
class HashJoinExecutor : public AbstractExecutor {
 public:
  /**
   * Construct a new HashJoinExecutor instance.
   * @param exec_ctx The executor context
   * @param plan The HashJoin join plan to be executed
   * @param left_child The child executor that produces tuples for the left side of join
   * @param right_child The child executor that produces tuples for the right side of join
   */
  HashJoinExecutor(ExecutorContext *exec_ctx, const HashJoinPlanNode *plan,
                   std::unique_ptr<AbstractExecutor> &&left_child, std::unique_ptr<AbstractExecutor> &&right_child);

  /** Initialize the join */
  void Init() override;

  /**
   * Yield the next tuple from the join.
   * @param[out] tuple The next tuple produced by the join.
   * @param[out] rid The next tuple RID, not used by hash join.
   * @return `true` if a tuple was produced, `false` if there are no more tuples.
   */
  auto Next(Tuple *tuple, RID *rid) -> bool override;

  /** Generate inner-join output (defined out of line; presumably into res_ — confirm). */
  auto GenForInnerJoin() -> void;

  /** Generate left-outer-join output (defined out of line; presumably into res_ — confirm). */
  auto GenForLeftJoin() -> void;

  /** @return The output schema for the join */
  auto GetOutputSchema() const -> const Schema & override { return plan_->OutputSchema(); }

 private:
  /**
   * Evaluate every expression in `exprs` against `tuple` (interpreted with
   * `schema`) and collect the results, in order, into a JoinKey.
   */
  template <typename ExprList>
  static auto BuildJoinKey(const Tuple *tuple, const ExprList &exprs, const Schema &schema) -> JoinKey {
    std::vector<Value> keys;
    keys.reserve(exprs.size());
    for (const auto &expr : exprs) {
      keys.emplace_back(expr->Evaluate(tuple, schema));
    }
    return JoinKey{std::move(keys)};
  }

  /** Join key of a right-child tuple, using the plan's right key expressions. */
  auto MakeRightJoinKey(const Tuple *tuple) const -> JoinKey {
    return BuildJoinKey(tuple, plan_->RightJoinKeyExpressions(), right_child_->GetOutputSchema());
  }

  /** Join key of a left-child tuple, using the plan's left key expressions. */
  auto MakeLeftJoinKey(const Tuple *tuple) const -> JoinKey {
    return BuildJoinKey(tuple, plan_->LeftJoinKeyExpressions(), left_child_->GetOutputSchema());
  }

  /** Copy every column of a right-child tuple, in column-index order, into a JoinValue. */
  auto MakeJoinValue(const Tuple *tuple) const -> JoinValue {
    const Schema &schema = right_child_->GetOutputSchema();
    const auto col_count = schema.GetColumnCount();
    std::vector<Value> vals;
    vals.reserve(col_count);
    for (uint32_t i = 0; i < col_count; ++i) {
      vals.emplace_back(tuple->GetValue(&schema, i));
    }
    return JoinValue{std::move(vals)};
  }

  /** The HashJoin plan node to be executed. */
  const HashJoinPlanNode *plan_;
  /** Producer of the left input; its output schema is used for left join keys. */
  std::unique_ptr<AbstractExecutor> left_child_;
  /** Producer of the right input; its output schema is used for right join keys and JoinValues. */
  std::unique_ptr<AbstractExecutor> right_child_;
  /** Hash table of JoinKey -> JoinValue entries. */
  HashJoinTable hjt_;
  /** Buffered join output — NOTE(review): presumably drained by Next() via cursor_; confirm in the .cpp. */
  std::vector<Tuple> res_;
  /** Position of the next tuple in res_ to emit (presumably). */
  size_t cursor_{0};
};

}  // namespace bustub
